import pandas as pd
import re

# Define a function to parse a single line of the log file
def _parse_profile_log_line(line):
    """Parse one profile-log line written in the current layout.

    The current layout is the one whose last field is ``Server Comp Time``.
    Matching is anchored at the start of ``line`` (``re.match``); anything
    after the last field is ignored.

    Args:
        line: A single line of text from the profile log.

    Returns:
        A dict mapping each field name to the matched text (values are left
        as strings), or ``None`` when the line does not follow this layout.
    """
    # One regex fragment per ' | '-separated field, in the order logged.
    fields = [
        r'Step (?P<step>\d+)',
        r'Client (?P<client>\d+)',
        r'Sampling = (?P<sampling>[GL])',
        r'Sampling Time = (?P<sampling_time>[\d.]+)ms',
        r'GPU Data Transfer Time = (?P<gpu_data_transfer_time>[\d.]+)ms',
        r'Compute Time = (?P<compute_time>[\d.]+)ms',
        r'Total Local Data Size = (?P<local_data_size>[\d.]+)MB',
        r'Total Remote Data Size = (?P<remote_data_size>[\d.]+)MB',
        r'Total Input Feature Size = (?P<input_feature_size>[\d.]+)MB',
        r'Server Comp Time = (?P<server_comp_time>[\d.]+)ms',
    ]
    found = re.match(r' \| '.join(fields), line)
    return found.groupdict() if found else None

def _parse_profile_log_line_old(line):
    """Parse one profile-log line written in the legacy layout.

    The legacy layout stops at ``Total Input Feature Size`` and has no
    ``Server Comp Time`` field. Matching is anchored at the start of
    ``line`` (``re.match``) and ignores trailing text, so a line that carries
    additional fields after ``Total Input Feature Size`` still matches.

    Args:
        line: A single line of text from the profile log.

    Returns:
        A dict mapping each field name to the matched text (values are left
        as strings), or ``None`` when the line does not follow this layout.
    """
    # One regex fragment per ' | '-separated field, in the order logged.
    fields = [
        r'Step (?P<step>\d+)',
        r'Client (?P<client>\d+)',
        r'Sampling = (?P<sampling>[GL])',
        r'Sampling Time = (?P<sampling_time>[\d.]+)ms',
        r'GPU Data Transfer Time = (?P<gpu_data_transfer_time>[\d.]+)ms',
        r'Compute Time = (?P<compute_time>[\d.]+)ms',
        r'Total Local Data Size = (?P<local_data_size>[\d.]+)MB',
        r'Total Remote Data Size = (?P<remote_data_size>[\d.]+)MB',
        r'Total Input Feature Size = (?P<input_feature_size>[\d.]+)MB',
    ]
    found = re.match(r' \| '.join(fields), line)
    return found.groupdict() if found else None

# Define a function to parse the entire log file
def _parse_profile_log_file(log_file_content):
    """Turn the lines of a profile log into a DataFrame of parsed records.

    Each line is tried against the current layout first and the legacy
    layout second; lines matching neither are skipped.

    Args:
        log_file_content: Iterable of log lines (for example the list
            returned by ``readlines()``).

    Returns:
        A ``pandas.DataFrame`` with one row per recognised line and one
        column per parsed field. Values are the matched strings.
    """
    records = []
    for raw_line in log_file_content:
        # Fall back to the legacy layout only when the current one fails.
        record = _parse_profile_log_line(raw_line) or _parse_profile_log_line_old(raw_line)
        if record:
            records.append(record)
    return pd.DataFrame(records)

# Define a function to compute timing and data-transfer statistics from the parsed profile log
def _mean_profile_columns(profile_df, sampling_kind, columns):
    """Average ``columns`` over the rows whose ``sampling`` equals ``sampling_kind``."""
    rows = profile_df[profile_df['sampling'] == sampling_kind]
    # Columns are cast to float before averaging (the regex parser yields strings).
    return {name: rows[name].astype(float).mean() for name in columns}


def _transfer_time_ms(data_mb, bandwidth):
    """Convert a data volume into a link transfer time in milliseconds.

    ``data_mb`` is scaled by 1024 * 1024 * 8 (to bits) and ``bandwidth`` by
    1024 * 1024 * 1024 (to bits per second); the quotient, in seconds, is
    multiplied by 1000.
    """
    return data_mb * 1024 * 1024 * 8 / (bandwidth * 1024 * 1024 * 1024) * 1000


def _process_profile_log(profile_df, is_llcg, bandwidth):
    """Average the parsed profile records and derive training/transfer times.

    Rows are split on the ``sampling`` column into local (``'L'``) and global
    (``'G'``) groups, and each metric is averaged per group. A group with no
    rows yields NaN averages. The data volume of each group is printed as a
    side effect.

    The server computation time is not part of any returned figure, so it is
    not read here; this also lets frames built purely from legacy-layout lines
    (which have no ``server_comp_time`` column) be processed.

    Args:
        profile_df: DataFrame with the columns ``sampling``,
            ``sampling_time``, ``gpu_data_transfer_time``, ``compute_time``,
            ``local_data_size`` and ``remote_data_size``.
        is_llcg: When truthy the transferred volume is local + remote data
            size; otherwise the local data size is counted twice.
        bandwidth: Link bandwidth used as the divisor of the transfer-time
            formula (multiplied by 1024 ** 3 there).

    Returns:
        A 6-tuple ``(local_training_time, local_transfer_time,
        global_training_time, global_transfer_time, local_sampling_time,
        global_sampling_time)``. Training time is GPU data transfer time plus
        compute time; sampling time is reported separately.

    Raises:
        KeyError: If a required column is missing, e.g. when no log line
            could be parsed and the frame is empty.
    """
    columns = (
        'sampling_time',
        'gpu_data_transfer_time',
        'compute_time',
        'local_data_size',
        'remote_data_size',
    )
    # Compute both groups before printing so a malformed frame fails
    # without emitting partial output.
    local = _mean_profile_columns(profile_df, 'L', columns)
    global_ = _mean_profile_columns(profile_df, 'G', columns)

    def data_volume(means):
        if is_llcg:
            return means['local_data_size'] + means['remote_data_size']
        return 2 * means['local_data_size'] + means['remote_data_size']

    local_training_time = local['gpu_data_transfer_time'] + local['compute_time']
    local_data = data_volume(local)
    print('Local Data Transfer Amont in MB: ', local_data)
    local_transfer_time = _transfer_time_ms(local_data, bandwidth)

    global_training_time = global_['gpu_data_transfer_time'] + global_['compute_time']
    global_data = data_volume(global_)
    print('Global Data Transfer Amont in MB: ', global_data)
    global_transfer_time = _transfer_time_ms(global_data, bandwidth)

    return (
        local_training_time,
        local_transfer_time,
        global_training_time,
        global_transfer_time,
        local['sampling_time'],
        global_['sampling_time'],
    )

# Function to analyze profile log
def analyze_profile_log(profile_file, is_llcg, bandwidth):
    """Load a profile log from disk and summarise it.

    The file is read line by line, parsed with ``_parse_profile_log_file``
    and reduced with ``_process_profile_log``.

    Args:
        profile_file: Path of the profile log to read.
        is_llcg: Forwarded to ``_process_profile_log``.
        bandwidth: Forwarded to ``_process_profile_log``.

    Returns:
        The 6-tuple ``(local_training_time, local_transfer_time,
        global_training_time, global_transfer_time, local_sampling_time,
        global_sampling_time)`` produced by ``_process_profile_log``.
    """
    with open(profile_file, 'r') as handle:
        lines = list(handle)

    records = _parse_profile_log_file(lines)
    (
        local_train,
        local_transfer,
        global_train,
        global_transfer,
        local_sampling,
        global_sampling,
    ) = _process_profile_log(records, is_llcg=is_llcg, bandwidth=bandwidth)

    return (
        local_train,
        local_transfer,
        global_train,
        global_transfer,
        local_sampling,
        global_sampling,
    )

def analyze_detail_profile_log(profile_file, is_llcg, bandwidth):
    """Load a profile log and derive training times net of server computation.

    Rows are split on the ``sampling`` column into local (``'L'``) and global
    (``'G'``) groups and every metric is averaged per group. The averaged
    sampling times and the data volume of each group are printed.

    Args:
        profile_file: Path of the profile log to read.
        is_llcg: When truthy the transferred volume is local + remote data
            size; otherwise the local data size is counted twice.
        bandwidth: Link bandwidth used as the divisor of the transfer-time
            formula (multiplied by 1024 ** 3 there).

    Returns:
        A 6-tuple ``(local_training_time, local_transfer_time,
        global_training_time, global_transfer_time, local_server_comp_time,
        global_server_comp_time)``. Training time here is sampling time +
        GPU data transfer time + compute time - server computation time.

    Raises:
        KeyError: If a required column is missing from the parsed frame,
            including ``server_comp_time`` when no parsed line carried it.
    """
    with open(profile_file, 'r') as handle:
        lines = handle.readlines()
    frame = _parse_profile_log_file(lines)

    metrics = (
        'sampling_time',
        'gpu_data_transfer_time',
        'compute_time',
        'local_data_size',
        'remote_data_size',
        'server_comp_time',
    )

    def averages(kind):
        # Mean of every metric over the rows of one sampling type.
        subset = frame[frame['sampling'] == kind]
        return {name: subset[name].astype(float).mean() for name in metrics}

    def summarise(avg, banner):
        # Training time excludes the part spent computing on the server.
        training_time = (
            avg['sampling_time']
            + avg['gpu_data_transfer_time']
            + avg['compute_time']
            - avg['server_comp_time']
        )
        if is_llcg:
            volume = avg['local_data_size'] + avg['remote_data_size']
        else:
            volume = 2 * avg['local_data_size'] + avg['remote_data_size']
        print(banner, volume)
        transfer_time = volume * 1024 * 1024 * 8 / (bandwidth * 1024 * 1024 * 1024) * 1000
        return training_time, transfer_time

    local_avg = averages('L')
    global_avg = averages('G')
    print('local_sampling_time: ', local_avg['sampling_time'])
    print('global_sampling_time: ', global_avg['sampling_time'])

    local_training_time, local_transfer_time = summarise(
        local_avg, 'Local Data Transfer Amont in MB: ')
    global_training_time, global_transfer_time = summarise(
        global_avg, 'Global Data Transfer Amont in MB: ')

    return (
        local_training_time,
        local_transfer_time,
        global_training_time,
        global_transfer_time,
        local_avg['server_comp_time'],
        global_avg['server_comp_time'],
    )

# Define a function to parse a single line of the validation log file
def _parse_validation_log_line(line):
    """Classify and parse one line of a validation log.

    Three layouts are recognised, tried in this order:

    * ``'sync'``: step plus synchronization time;
    * ``'extended_validation'``: step, validation accuracy, validation time,
      sampling letter and sampling interval;
    * ``'validation'``: step, validation accuracy and validation time.

    The extended layout is tried before the plain one because the plain
    pattern is a prefix of it and would otherwise win.

    Args:
        line: A single line of text from the validation log.

    Returns:
        ``(fields, kind)`` where ``fields`` maps field names to the matched
        strings and ``kind`` is one of the labels above, or ``(None, None)``
        when the line matches no layout.
    """
    step = r'Step (?P<step>\d+) \| '
    accuracy_and_time = (
        r'Validation Accuracy = (?P<validation_accuracy>[\d.]+) \| '
        r'Validation time = (?P<validation_time>[\d.]+)ms'
    )
    sampling_suffix = (
        r' \| '
        r'Sampling = (?P<sampling>[A-Z]) \| '
        r'Sampling interval = (?P<sampling_interval>\d+)'
    )
    layouts = (
        ('sync', step + r'Synchronization time = (?P<synchronization_time>[\d.]+)ms'),
        ('extended_validation', step + accuracy_and_time + sampling_suffix),
        ('validation', step + accuracy_and_time),
    )

    for kind, regex in layouts:
        hit = re.match(regex, line)
        if hit:
            return hit.groupdict(), kind
    return None, None

# Define a function to parse the entire validation log file
def _parse_validation_log_file(log_file_path):
    """Read a validation log and split it into sync and validation records.

    Every line is classified by ``_parse_validation_log_line``. Lines of kind
    ``'sync'`` go to the first frame; lines of kind ``'validation'`` or
    ``'extended_validation'`` go to the second. Unrecognised lines are
    skipped.

    Args:
        log_file_path: Path of the validation log to read.

    Returns:
        ``(df_sync, df_validation)`` as two ``pandas.DataFrame`` objects
        holding the parsed string fields, or ``(None, None)`` if reading or
        parsing raised; the error is printed in that case.
    """
    sync_data = []
    validation_data = []
    try:
        with open(log_file_path, 'r') as file:
            for line in file:
                parsed_line, line_type = _parse_validation_log_line(line)
                if not parsed_line:
                    continue
                if line_type == 'sync':
                    sync_data.append(parsed_line)
                # An explicit membership test: the previous
                # `== 'validation' or '<literal>'` form was always true.
                elif line_type in ('validation', 'extended_validation'):
                    validation_data.append(parsed_line)
        return pd.DataFrame(sync_data), pd.DataFrame(validation_data)
    except Exception as e:
        # Deliberate best-effort: callers receive (None, None) rather than
        # an exception when the log cannot be read.
        print(f"Error reading file: {e}")
        return None, None

def analyze_train_log(train_log_file):
    """Parse a training/validation log into its two record tables.

    Args:
        train_log_file: Path of the validation log to read.

    Returns:
        The ``(df_sync, df_validation)`` pair produced by
        ``_parse_validation_log_file``; both entries are ``None`` when that
        helper fails to read the file.
    """
    return _parse_validation_log_file(train_log_file)

if __name__ == '__main__':
    # Profile logs of the two methods being compared.
    our_profile_file = '../output/profile_log/layer_2_ogbn-products_10_batchsize_256/profile.log'
    llcg_profile_file = '../output/profile_log/layer_2_ogbn-products_10_batchsize_256/profile_llcg.log'

    # analyze_profile_log returns six values (the last two are the local and
    # global sampling times), so all six must be unpacked.
    (our_local_train_time, our_local_comm_time,
     our_global_train_time, our_global_comm_time,
     our_local_sampling_time, our_global_sampling_time) = analyze_profile_log(
        our_profile_file, is_llcg=False, bandwidth=1)
    (llcg_local_train_time, llcg_local_comm_time,
     llcg_global_train_time, llcg_global_comm_time,
     llcg_local_sampling_time, llcg_global_sampling_time) = analyze_profile_log(
        llcg_profile_file, is_llcg=True, bandwidth=1)

    dataset = 'ogbn-products'
    method = 'our'
    if method == 'our':
        log_dir = '../output/log'
    else:
        # Fail with a clear message instead of a NameError on log_dir below.
        raise ValueError(f'No log directory configured for method {method!r}')
    num_parts = 10
    global_interval = 5
    sampled_clients = 5
    valid_file = f'{log_dir}/layer_2_{dataset}_{num_parts}_{global_interval}_{sampled_clients}_256_0.001/validation.log'

    # analyze_train_log accepts only the log path.
    df_sync, df_validation = analyze_train_log(valid_file)

